#!/usr/bin/env python3
import argparse, math
import numpy as np
import pandas as pd
from astropy.io import fits

def edges_from_csv(s: str):
    """Parse a comma-separated list of bin edges into floats.

    Whitespace around each token is ignored (``float`` tolerates it), and
    empty tokens, e.g. from a trailing comma as in ``"5,7.5,10,"``, are
    skipped. Without this, ``float("")`` raises on harmless CLI input. The
    edges are returned in the order given; monotonicity is not checked.

    Args:
        s: Comma-separated numbers, e.g. ``"5,7.5,10"``.

    Returns:
        A list of floats.

    Raises:
        ValueError: If a non-empty token is not a valid float.
    """
    return [float(tok) for tok in s.split(",") if tok.strip()]

def midpoints(edges):
    """Return the centre of every consecutive pair of bin edges.

    For ``n`` edges the result is a float ndarray of length ``n - 1`` whose
    ``k``-th entry is the mean of ``edges[k]`` and ``edges[k + 1]``.
    """
    arr = np.asarray(edges, dtype=float)
    lower, upper = arr[:-1], arr[1:]
    return (lower + upper) / 2.0

def assign_bin(x, edges):
    """Return the index ``i`` of the bin containing ``float(x)``, or None.

    Uses ``np.digitize(..., right=False)``. For increasing edges, bin ``i``
    is the half-open interval ``[edges[i], edges[i + 1])``. Values below the
    first edge, at or above the last edge, or NaN give None.
    """
    n_bins = len(edges) - 1
    pos = np.digitize([float(x)], edges, right=False)[0]
    idx = pos - 1
    if idx < 0 or idx >= n_bins:
        return None
    return idx

def bin_label(edges, i):
    """Format bin ``i`` as ``"<lower>–<upper>"`` using an en dash separator."""
    lo, hi = edges[i], edges[i + 1]
    return "{}–{}".format(lo, hi)

def make_colmap(tbl):
    """Map lower-cased column names of a table-like object to their real spelling.

    Names come from ``tbl.columns.names`` when that attribute exists,
    otherwise from the ``.name`` of each entry in ``tbl.columns``. If two
    names differ only in case, the later one wins.
    """
    cols = tbl.columns
    if hasattr(cols, "names"):
        names = list(cols.names)
    else:
        names = [col.name for col in cols]
    mapping = {}
    for name in names:
        mapping[name.lower()] = name
    return mapping

def pick(colmap, *cands):
    """Return the real column name of the first candidate present in ``colmap``.

    Each candidate is lower-cased and looked up in ``colmap`` (a mapping of
    lower-cased name -> actual name, as produced by ``make_colmap``), in the
    order given.

    Raises:
        KeyError: If none of the candidates is present.
    """
    hit = next(
        (colmap[c.lower()] for c in cands if colmap.get(c.lower()) is not None),
        None,
    )
    if hit is None:
        raise KeyError(f"Missing expected column. Tried {cands}, have {list(colmap.values())}")
    return hit

def _pick_optional(colmap, *cands):
    """Return the first matching column name like ``pick``, or None if none match."""
    try:
        return pick(colmap, *cands)
    except KeyError:
        return None


def _read_sources(path):
    """Read the source-catalogue columns used for stacking from HDU 1 of a FITS file.

    Returns a dict with float64 arrays under "ra", "dec", "e1", "e2" and "w".
    The keys "z" and "m" hold an array when a matching column exists, else
    None. All arrays are reordered by ascending Dec, so a per-lens Dec window
    is one contiguous slice found with ``np.searchsorted``.

    Raises:
        KeyError: If a required column (RA, Dec, e1, e2, weight) is missing.
    """
    # `with` guarantees the file is closed. `.astype(float)` returns in-memory
    # copies, so no returned array refers to the memory map.
    with fits.open(path, memmap=True) as hdul:
        tbl = hdul[1]
        colmap = make_colmap(tbl)
        names = {
            "ra": pick(colmap, "RAJ2000", "ALPHA_J2000", "RA", "ALPHAWIN_J2000"),
            "dec": pick(colmap, "DECJ2000", "DELTA_J2000", "DEC", "DELTAWIN_J2000"),
            "e1": pick(colmap, "e1", "ELLIP1"),
            "e2": pick(colmap, "e2", "ELLIP2"),
            "w": pick(colmap, "weight", "W"),
            "z": _pick_optional(colmap, "z_B", "Z_B", "PHOTOZ", "ZPHOT"),
            "m": _pick_optional(colmap, "m", "MCOR", "m_corr", "M"),
        }
        data = tbl.data
        cols = {key: (None if name is None else data[name].astype(float))
                for key, name in names.items()}

    order = np.argsort(cols["dec"], kind="stable")
    # Reorder one array at a time to keep peak memory to one extra column.
    for key in cols:
        if cols[key] is not None:
            cols[key] = cols[key][order]
    return cols


def _lens_sums(ra_l, dec_l, z_l, src, b_edges, max_deg, min_zsep, use_m_corr):
    """Compute one lens's tangential-shear sums in projected-separation bins.

    Sources count when they satisfy all of:
    - they lie inside the square of half-width ``max_deg`` around the lens;
    - they pass ``z >= z_l + min_zsep``, applied only if ``src["z"]`` is not None;
    - their separation falls within ``b_edges`` (arcsec).

    Returns three arrays of length ``len(b_edges) - 1``:
    ``sum(w * e_t)``, ``sum(w)`` and ``sum(w * (1 + m))``. The last equals
    ``sum(w)`` unless the m-correction is active.
    """
    nb = len(b_edges) - 1
    dec = src["dec"]
    # Sources are Dec-sorted, so the box's Dec range is one contiguous slice.
    lo = np.searchsorted(dec, dec_l - max_deg, side="left")
    hi = np.searchsorted(dec, dec_l + max_deg, side="right")
    sl = slice(lo, hi)

    cosd = math.cos(math.radians(dec_l))
    # Wrap the RA difference into [-180, 180) so lenses near RA = 0/360 still
    # see sources on the other side of the wrap.
    dra = (np.mod(src["ra"][sl] - ra_l + 180.0, 360.0) - 180.0) * cosd
    dde = dec[sl] - dec_l
    sel = (np.abs(dra) < max_deg) & (np.abs(dde) < max_deg)
    if src["z"] is not None:
        sel &= src["z"][sl] >= z_l + min_zsep

    dx = dra[sel] * 3600.0  # arcsec, +x towards increasing RA
    dy = dde[sel] * 3600.0  # arcsec, +y towards increasing Dec
    r = np.hypot(dx, dy)
    phi = np.arctan2(dy, dx)
    # NOTE(review): assumes e1/e2 are defined in this RA/Dec frame; confirm the
    # catalogue's e2 sign convention (some catalogues require e2 -> -e2).
    e_t = -(src["e1"][sl][sel] * np.cos(2 * phi) + src["e2"][sl][sel] * np.sin(2 * phi))

    w = src["w"][sl][sel]
    if use_m_corr and src["m"] is not None:
        # Multiplicative bias belongs in the response (denominator), not the
        # weights: <gamma_t> = sum(w e_t) / sum(w (1 + m)).
        resp = w * np.clip(1.0 + src["m"][sl][sel], 0.5, 2.0)
    else:
        resp = w

    idx = np.digitize(r, b_edges) - 1
    good = (idx >= 0) & (idx < nb)
    bins = idx[good]
    # bincount with minlength=nb yields zero-filled arrays when nothing survives.
    s_we = np.bincount(bins, weights=(w * e_t)[good], minlength=nb)[:nb]
    s_w = np.bincount(bins, weights=w[good], minlength=nb)[:nb]
    s_wr = np.bincount(bins, weights=resp[good], minlength=nb)[:nb]
    return s_we, s_w, s_wr


def main():
    """Build stacked tangential-shear profiles from a KiDS source catalogue.

    Lenses from ``--lenses`` are grouped by (R_G, log M*) bin. For every
    group, per-separation-bin sums of w*e_t, w and w*(1+m) are accumulated.
    The script then writes:
    - ``--out``: one row per (stack, separation bin) with a finite gamma_t.
      gamma_t = sum(w e_t) / sum(w (1+m)); weight = sum(w).
    - ``--out-meta``: per stack, the lens count and mean R_G in kpc.
    """
    import sys
    from pathlib import Path

    ap = argparse.ArgumentParser(description="KiDS -> prestacked_stacks.csv (tangential shear vs b, grouped by size/mass bins)")
    ap.add_argument("--kids", required=True, help="KiDS WL FITS (e.g. KiDS_DR4.1_..._SOM_gold_WL_cat.fits)")
    ap.add_argument("--lenses", required=True, help="CSV with lens_id,ra_deg,dec_deg,z_lens,RG_kpc,Mstar_log10")
    ap.add_argument("--out", default="data/prestacked_stacks.csv")
    ap.add_argument("--out-meta", default="data/prestacked_meta.csv")
    ap.add_argument("--rg-bins", default="5,7.5,10,12.5,15", help="kpc edges (comma-separated)")
    ap.add_argument("--mstar-bins", default="10.2,10.5,10.8,11.1", help="log10(M*) edges (comma-separated)")
    ap.add_argument("--b-bins-arcsec",
                    default="10,15,22,32,46,66,95,137,198,285,410,592,855,1236,1787,2583",
                    help="separation edges (arcsec, comma-separated)")
    ap.add_argument("--min-zsep", type=float, default=0.1, help="require z_source >= z_lens + min_zsep (if photo-z present)")
    ap.add_argument("--use-m-corr", action="store_true", help="apply 1/(1+m) if an m column exists")
    ap.add_argument("--max-lenses", type=int, default=None, help="limit lenses for a quick smoke test")
    args = ap.parse_args()

    # Bin definitions
    rg_edges = edges_from_csv(args.rg_bins)
    ms_edges = edges_from_csv(args.mstar_bins)
    b_edges = np.array(edges_from_csv(args.b_bins_arcsec), dtype=float)
    b_mids = midpoints(b_edges)
    nb = len(b_mids)
    # Half-width (deg) of the square pre-cut; covers the largest separation.
    max_deg = float(b_edges[-1]) / 3600.0

    # Lenses
    lenses = pd.read_csv(args.lenses)
    need = ["ra_deg", "dec_deg", "z_lens", "RG_kpc", "Mstar_log10"]
    for c in need:
        if c not in lenses.columns:
            raise SystemExit(f"Missing column '{c}' in {args.lenses}")
    lenses = lenses.dropna(subset=need)
    if args.max_lenses:
        lenses = lenses.head(args.max_lenses)

    # Sources
    src = _read_sources(args.kids)
    if args.use_m_corr and src["m"] is None:
        print("warning: --use-m-corr given but no m column found; no m-correction applied",
              file=sys.stderr)

    # Accumulators keyed by (rg_bin_index, mstar_bin_index).
    stacks = {}
    for _, lens in lenses.iterrows():
        rgi = assign_bin(lens["RG_kpc"], rg_edges)
        msi = assign_bin(lens["Mstar_log10"], ms_edges)
        if rgi is None or msi is None:
            continue
        s_we, s_w, s_wr = _lens_sums(
            float(lens["ra_deg"]), float(lens["dec_deg"]), float(lens["z_lens"]),
            src, b_edges, max_deg, args.min_zsep, args.use_m_corr,
        )
        acc = stacks.setdefault((int(rgi), int(msi)), {
            "sum_we_t": np.zeros(nb), "sum_w": np.zeros(nb), "sum_wr": np.zeros(nb),
            "nL": 0, "sum_rg": 0.0,
        })
        acc["sum_we_t"] += s_we
        acc["sum_w"] += s_w
        acc["sum_wr"] += s_wr
        # Lenses with no usable sources still count towards the stack.
        acc["nL"] += 1
        acc["sum_rg"] += float(lens["RG_kpc"])

    # Assemble outputs in deterministic (size bin, mass bin) order.
    rows, meta = [], []
    for rgi, msi in sorted(stacks):
        acc = stacks[(rgi, msi)]
        rg_lab = bin_label(rg_edges, rgi)
        ms_lab = bin_label(ms_edges, msi)
        stack_id = f"{rg_lab}_{ms_lab}"
        with np.errstate(divide="ignore", invalid="ignore"):
            gamma_t = np.where(acc["sum_wr"] > 0, acc["sum_we_t"] / acc["sum_wr"], np.nan)
        for k in np.flatnonzero(np.isfinite(gamma_t)):
            rows.append({
                "stack_id": stack_id,
                "R_G_bin": rg_lab,
                "Mstar_bin": ms_lab,
                "b": float(b_mids[k]),       # arcsec midpoints
                "gamma_t": float(gamma_t[k]),
                "weight": float(acc["sum_w"][k]),
            })
        meta.append({
            "stack_id": stack_id,
            "n_lenses": int(acc["nL"]),
            "R_G_mean_kpc": acc["sum_rg"] / acc["nL"],
        })

    # Create output directories (the defaults live under data/) and pin column
    # order so empty outputs still carry a header.
    for path in (args.out, args.out_meta):
        Path(path).parent.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(rows, columns=["stack_id", "R_G_bin", "Mstar_bin", "b", "gamma_t", "weight"]).to_csv(args.out, index=False)
    pd.DataFrame(meta, columns=["stack_id", "n_lenses", "R_G_mean_kpc"]).to_csv(args.out_meta, index=False)
    print(f"Wrote {args.out} ({len(rows)} rows); meta {args.out_meta} ({len(meta)} stacks).")

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
